Skip to content

Make diffusion model conditioning more flexible #521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Jul 24, 2025

Conversation

arrjon
Copy link
Member

@arrjon arrjon commented Jun 24, 2025

I introduced a new keyword concatenated_input to the subnet_kwargs in the diffusion model. This keyword controls how inputs—such as parameters, noise, and condition—are fed into the model’s subnet.

Previously, the model assumed all inputs were 1D vectors and concatenated them directly for the default MLP subnet. However, for more flexible architectures—such as subnets designed to preserve or induce spatial structures—this assumption doesn't hold. Now we can also return all inputs separately directly to the subnet.

@arrjon arrjon requested review from stefanradev93 and vpratz June 24, 2025 13:41
@arrjon arrjon self-assigned this Jun 24, 2025
Copy link

codecov bot commented Jun 24, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Files with missing lines Coverage Δ
...w/networks/consistency_models/consistency_model.py 97.79% <100.00%> (+0.13%) ⬆️
...esflow/networks/diffusion_model/diffusion_model.py 79.79% <100.00%> (+0.64%) ⬆️
bayesflow/networks/flow_matching/flow_matching.py 93.44% <100.00%> (+0.58%) ⬆️

@vpratz
Copy link
Collaborator

vpratz commented Jun 24, 2025

Thanks for the PR, I think this is a reasonable idea for advanced use cases. As this is another instance of multi-input networks (even though it's inside the inference network this time, and does not involve the adapter), we might want to include this in the discussion in #517. We might also think about:

  • how to pass inputs: named (as a dictionary), or via position (as a tuple like you propose in the PR)
  • Consistency in our multi-step models. If we offer this possibility here, we might want to do the same in flow matching and consistency models.

Tagging @LarsKue for comment as well.

@stefanradev93
Copy link
Contributor

stefanradev93 commented Jun 24, 2025

One very general approach would be to break free from the fixed names, such as "inference_variables", and actually allow for:

  • Marking different simulator outputs as either target variables, summary variables or inference conditions
  • Selecting a strategy for how different outputs of the same type are handled (e.g., concatenated, passed as a tuple, or passed as keyword arguments)
    This can be handled with another abstraction, such as SimulatorOutput with a flexible scheme.

@arrjon
Copy link
Member Author

arrjon commented Jun 24, 2025

Consistency in our multi-step models. If we offer this possibility here, we might want to do the same in flow matching and consistency models.

I agree @vpratz. I added the same logic to both models. I am open to suggestions how the tensors should be passed to the subnet.

To @stefanradev93 comment: I think, this flexibility is only needed for advanced users. So maybe we should not follow this general approach for now, as the fixed names help users to get started with BayesFlow.

@stefanradev93
Copy link
Contributor

Consistency in our multi-step models. If we offer this possibility here, we might want to do the same in flow matching and consistency models.

I agree @vpratz. I added the same logic to both models. I am open to suggestions how the tensors should be passed to the subnet.

To @stefanradev93 comment: I think, this flexibility is only needed for advanced users. So maybe we should not follow this general approach for now, as the fixed names help users to get started with BayesFlow.

Absolutely, this is definitely a 2.>1.x idea.

@arrjon
Copy link
Member Author

arrjon commented Jul 8, 2025

For now, it is okay as it is @stefanradev93?

@vpratz
Copy link
Collaborator

vpratz commented Jul 8, 2025

I think we need to document somewhere how this can be used (i.e., which inputs are passed to the network if concatenate_subnet_input is False`), as this currently only exists in the code itself. I'd suggest we pass the inputs by name, as this eases up communication (only names, no order).

It would be good to have a test in place for the concatenate_subnet_input is False case.

@arrjon
Copy link
Member Author

arrjon commented Jul 9, 2025

Thanks @vpratz for the suggestions. I added the documentation.

Regarding the test: As we do not have a network at the moment which can handle multiple inputs, I do not know a useful test for the concatenate_subnet_input=False case. Any suggestions?

@vpratz
Copy link
Collaborator

vpratz commented Jul 11, 2025

@arrjon Thanks for the changes!
I would propose to add a simple wrapper network that can handle the case to the test suite, similar to other dummy networks we have. It could just take the separate inputs, concatenate them and pass them to any other network as usual. The main point for me is that we are able to notice if we accidentally break the functionality, so a basic dedicated test should be sufficient.

@arrjon
Copy link
Member Author

arrjon commented Jul 17, 2025

@vpratz I added the test. Now the merge should be ready! :)

@vpratz
Copy link
Collaborator

vpratz commented Jul 22, 2025

Thanks a lot @arrjon . Sorry for being slow to review, and to always coming up with new things I didn't notice before.

The build functions do not take the concatenate_subnet_input parameter into account when building the subnet. I'm not sure in when this is problematic and when not, but I think it would be good to pass the correct shapes there as well, to avoid weird problems down the line.

@arrjon
Copy link
Member Author

arrjon commented Jul 22, 2025

I corrected the shape for the build functions. Please verify the implementation, locally the tests passed, but I am not too confident with the implementation.

@arrjon
Copy link
Member Author

arrjon commented Jul 23, 2025

@vpratz now also the tests passed :)

@vpratz
Copy link
Collaborator

vpratz commented Jul 23, 2025

Great, I'll take a final look tomorrow and then merge it.

vpratz added 2 commits July 24, 2025 06:36
The convention is to use parameter name with a `_shape` suffix.
The cost of the continuous models on the CI is too high
@vpratz
Copy link
Collaborator

vpratz commented Jul 24, 2025

I have changed the naming for the shapes, and created a reduced test for testing the setting, as the inference network tests are really slow.

@vpratz vpratz merged commit d9e9782 into dev Jul 24, 2025
9 checks passed
@vpratz vpratz deleted the diffusion-model-conditioning branch July 24, 2025 09:14
@arrjon
Copy link
Member Author

arrjon commented Jul 24, 2025

Thanks @vpratz!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants